{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Commonsense Path Generator for Connecting Entities\n",
    "In this notebook, we show how to use our proposed path generator to generate a commonsense relational path\n",
    "for connecting a pair of entities. You can then use the generator as a plug-in module for providing structured\n",
    "evidence to any downstream task."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import transformers\n",
    "assert transformers.__version__ == '2.8.0'\n",
    "from transformers import GPT2Config, GPT2Tokenizer, GPT2Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the generator model\n",
    "class Generator(nn.Module):\n",
    "    def __init__(self, gpt, config, max_len=31):\n",
    "        super(Generator, self).__init__()\n",
    "        self.gpt = gpt\n",
    "        self.config = config\n",
    "        self.max_len = max_len\n",
    "        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # input: [batch, seq]\n",
    "        context_len = inputs.size(1)\n",
    "        generated = inputs\n",
    "        next_token = inputs\n",
    "        past = None\n",
    "        with torch.no_grad():\n",
    "            for step in range(self.max_len):\n",
    "                outputs = self.gpt(next_token, past=past)\n",
    "                hidden = outputs[0][:, -1]\n",
    "                past = outputs[1]\n",
    "                next_token_logits = self.lm_head(hidden)\n",
    "                next_logits, next_token = next_token_logits.topk(k=1, dim=1)\n",
    "                generated = torch.cat((generated, next_token), dim=1)\n",
    "        return generated "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download a well-trained generator from this link to your local workspace:\n",
    "https://drive.google.com/file/d/1dQNxyiP4g4pdFQD6EPMQdzNow9sQevqD/view?usp=sharing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm_type = 'gpt2'\n",
    "pretrain_generator_ckpt = \"/your_path_to_the_download_checkpoint/commonsense-path-generator.ckpt\"\n",
    "\n",
    "# Tokenizer, extended with three extra tokens (added one at a time, in this order).\n",
    "tokenizer = GPT2Tokenizer.from_pretrained(lm_type)\n",
    "for extra_token in ('<PAD>', '<SEP>', '<END>'):\n",
    "    tokenizer.add_tokens([extra_token])\n",
    "\n",
    "# Config and backbone, both sized to the extended vocabulary.\n",
    "config = GPT2Config.from_pretrained(lm_type)\n",
    "config.vocab_size = len(tokenizer)\n",
    "gpt = GPT2Model.from_pretrained(lm_type)\n",
    "gpt.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "# Wrap the backbone and restore the downloaded weights on CPU.\n",
    "generator = Generator(gpt, config)\n",
    "checkpoint_state = torch.load(pretrain_generator_ckpt, map_location='cpu')\n",
    "generator.load_state_dict(checkpoint_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_input(head_entity, tail_entity, input_len=16):\n",
    "    \"\"\"Encode an entity pair as a fixed-length batch of one for the generator.\n",
    "\n",
    "    Underscores in both entities become spaces; the text given to the tokenizer\n",
    "    is the tail entity, then '<SEP>', then the head entity. The ids are\n",
    "    truncated to ``input_len`` and right-padded with the '<PAD>' id.\n",
    "\n",
    "    Returns:\n",
    "        A ``torch.long`` tensor holding a single row of ``input_len`` ids.\n",
    "    \"\"\"\n",
    "    head_text, tail_text = (name.replace('_', ' ') for name in (head_entity, tail_entity))\n",
    "    token_ids = tokenizer.encode(tail_text + '<SEP>' + head_text, add_special_tokens=False)\n",
    "    token_ids = token_ids[:input_len]\n",
    "    pad_id = tokenizer.convert_tokens_to_ids('<PAD>')\n",
    "    num_pad = input_len - len(token_ids)\n",
    "    return torch.tensor([token_ids + [pad_id] * num_pad], dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def connect_entities(head_entity, tail_entity):\n",
    "    \"\"\"Generate a relational path linking ``head_entity`` to ``tail_entity``.\n",
    "\n",
    "    Returns:\n",
    "        The decoded generator output with '<PAD>' removed and whitespace\n",
    "        collapsed, keeping only the text that follows the '<SEP>' marker.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: From ``str.index`` when '<SEP>' is absent from the decoded text.\n",
    "    \"\"\"\n",
    "    output_ids = generator(prepare_input(head_entity, tail_entity))[0].tolist()\n",
    "    decoded = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
    "    words = decoded.replace('<PAD>', '').split()\n",
    "    path = ' '.join(words)\n",
    "    # 6 == len('<SEP>') + 1: also skips the single character after the marker.\n",
    "    sep_at = path.index('<SEP>')\n",
    "    return path[sep_at + 6:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Usage Example\n",
    "- Input: A pair of entities you want to connect, expressed in natural language.\n",
    "- Output: A relational path in the form of (head_entity, relation1, intermediate_entity1, relation2, ..., tail_entity)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "curiosity causesdesire find information hassubevent read _hasprerequisite hear news\n"
     ]
    }
   ],
   "source": [
    "head_entity = 'curiosity'\n",
    "tail_entity = 'hear_news'\n",
    "path = connect_entities(head_entity, tail_entity)\n",
    "print(path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
